rm(list=ls())
library(dplyr)
library(stringr)
library(combinat)
library(reshape2)
library(foreach)
library(glue)
library(doParallel)

# Load the stored simulation truths; every element of the saved list becomes
# a variable in the global environment (the code below reads true_AME,
# true_beta, NL and NJ from there; NCLUSTER is read further down the script).
list2env(readRDS('code/bs_true_values.RDS'), envir = globalenv())

# True AMEs, relabelled to the keys used when joining against estimates:
# factor -> 'X<factor>', level -> 'f<factor>_<letter>'.
# `level` must be built before `factor` is overwritten, because it uses the
# original factor value.
true_AME <- transmute(
  ungroup(true_AME),
  group = cluster,
  truth = mean,
  level = paste0('f', factor, '_', letters[level]),
  factor = paste0('X', factor)
)

# True coefficients: melt each element of true_beta (rows labelled 1..NJ,
# columns labelled a, b, ...) to long form, stack them with the list
# name/position as `group`, and apply the same factor/level labels.
true_flat_beta <- lapply(true_beta, FUN = function(beta_k) {
  rownames(beta_k) <- 1:NJ
  colnames(beta_k) <- letters[1:NL]
  reshape2::melt(beta_k)
})
true_flat_beta <- transmute(
  bind_rows(true_flat_beta, .id = 'group'),
  group = as.numeric(group),
  truth = value,
  factor = paste0('X', Var1),
  level = paste0('f', Var1, '_', Var2)
)

# Scratch directory under which each archive is unpacked
tmpdir <- tempdir()

# Get the between, within, and total sum of squares of the
# estimated HTE/CAMCE
#
# Arguments:
#   HTE:   data.frame with columns `group` (unit identifier), `factor`,
#          `level` and a numeric `est`.
#   ppred: data.frame keyed by `group` with one weight column per cluster
#          (cluster membership probabilities); the weight columns are the
#          ones whose names match `regex`.
#   regex: regular expression selecting the weight columns. The default keeps
#          the behaviour of existing two-argument calls; previously any value
#          passed here was silently overwritten inside the function.
#
# Returns: one row per (factor, level) with `total_sos`, `between_sos`,
#   `within_sos` and `checksum_sos = total_sos - (between_sos + within_sos)`.
heterogeneity_HTE <- function(HTE, ppred, regex = '^X[0-9]+$|^group_[0-9]+$'){
  
  # Merge in cluster membership probabilities
  aug_HTE <- left_join(HTE, ppred, by = 'group')
  # Get the weighted within-cluster centroid 
  # (one column per cluster, weights = that cluster's membership column)
  group_centroid <- aug_HTE %>% 
    group_by(factor, level) %>% 
    summarize(across(matches(regex), ~ weighted.mean(est, .)),
              .groups = 'drop')
  # Total weight assigned to each cluster
  group_size <- aug_HTE %>% 
    group_by(factor, level) %>% 
    summarize(across(matches(regex), sum), .groups = 'drop')
  # Flatten to get one row per person/cluster/CAMCE
  # with the probability of being in that group.
  # pivot_longer is namespace-qualified because the top of this script does
  # not attach tidyr.
  long_HTE <- aug_HTE %>% 
    tidyr::pivot_longer(cols = matches(regex)) %>% 
    rename(person = group) %>%
    rename(group = name, prob = value)
  # Get the overall estimate, averaging across groups
  overall_centroid <- long_HTE %>%
    group_by(factor, level) %>%
    summarize(est = weighted.mean(est, prob), .groups = 'drop')
  # Renamed once so it can be joined next to the unit-level `est`
  overall_renamed <- overall_centroid %>% rename(overall_est = est)
  
  long_group_centroid <- group_centroid %>% 
    tidyr::pivot_longer(cols = matches(regex)) %>%
    rename(group = name, est_group = value)
  long_group_size <- group_size %>%
    tidyr::pivot_longer(cols = matches(regex)) %>%
    rename(group = name, size_group = value)
  
  # Attach the cluster centroid and the overall centroid to every
  # person/cluster row
  long_HTE <- left_join(
    long_HTE, 
    long_group_centroid, 
    by = c('group', 'factor', 'level'))
  long_HTE <- left_join(
    long_HTE,
    overall_renamed,
    by = c('factor', 'level')
  )
  # Within: weighted squared distance of each unit to its cluster centroid
  within_sos <- long_HTE %>% 
    group_by(group, factor, level) %>% 
    summarize(within_sos = sum( (est - est_group)^2 * prob ), .groups = 'drop') %>% 
    group_by(factor, level) %>% summarize(within_sos = sum(within_sos), .groups = 'drop')
  
  # Cluster-level table: centroid, size and overall centroid
  dta_out <- full_join(
    long_group_centroid,
    long_group_size,
    by = c('group', 'factor', 'level')
  )
  dta_out <- full_join(
    dta_out,
    overall_renamed,
    by = c('factor', 'level')
  )
  
  # Between: size-weighted squared distance of cluster centroids to the
  # overall centroid
  between_sos <- dta_out %>% 
    group_by(factor, level) %>% 
    summarize(between_sos = sum( (est_group - overall_est)^2 * size_group),
              .groups = 'drop')
  
  # Total: weighted squared distance of each unit to the overall centroid
  total_sos <- long_HTE %>% 
    group_by(factor, level) %>% 
    summarize(total_sos = sum( (est -   overall_est)^2 * prob ), .groups = 'drop')
  # Combine together
  all_sos <- left_join(
    left_join(total_sos, between_sos, by = c("factor", "level")), 
    within_sos, by = c('factor', 'level'))
  # checksum_sos should be numerically zero when the decomposition holds
  all_sos <- all_sos %>% mutate(
    checksum_sos = total_sos - (between_sos + within_sos))
  return(all_sos)
}

# Simulation ids found on disk: the first run of digits in each entry of
# AOAS_simulation/, as sorted numbers (entries without digits are dropped by
# sort()). 500 archives with ids 0..499 are expected; otherwise warn and
# print the missing ids.
file_list <- dir('AOAS_simulation') %>%
  str_extract(pattern = '[0-9]+') %>%
  as.numeric() %>%
  sort()
if (length(file_list) != 500) {
  warning('Missing some of the simulations', immediate. = TRUE)
  print(setdiff(0:499, file_list))
}

# Start a cluster of NCORES workers and register it as the %dopar% backend.
NCORES <- 10
cl <- makeCluster(NCORES)
registerDoParallel(cl)

# For every simulation archive: unpack it into its own scratch directory,
# score every RDS file inside against the truth, and return a list of summary
# tables (AME, cjoint, posterior, sos, HTE, OOS, HTE_binscatter, HTE_by_AME).
# The body reads objects defined earlier in this script: NCLUSTER, true_AME,
# true_flat_beta, heterogeneity_HTE() and tmpdir.
all_out <- foreach(slurm_id = file_list, .export = c('tmpdir'),
                   .packages = c('glue', 'Matrix', 'tidyr', 'dplyr', 'stringr', 'foreach', 'FactorHet')) %dopar% {
  
  print(slurm_id)
  # Unpack the archive into a per-archive scratch directory
  scratch_dir <- as.character(glue('{tmpdir}/out_{slurm_id}'))
  dir.create(scratch_dir, showWarnings = FALSE)
  system(glue("tar -xzf AOAS_simulation/out_{slurm_id}.tar.gz -C {tmpdir}/out_{slurm_id}"))
  path <- glue('{tmpdir}/out_{slurm_id}/out')
  
  
  files_to_load <- dir(path, pattern='RDS', full.names = TRUE)
  
  # Bin edges used for the binned scatter of estimated vs. true HTE
  grid_hte <- seq(-1, 1, by = 0.01)
  # All relabellings of the clusters and the names of the membership columns;
  # neither depends on the file or estimator, so compute them once.
  all_perm <- combinat::permn(seq_len(NCLUSTER))
  group_cols <- paste0('group_', seq_len(NCLUSTER))
  
  loop_parse <- foreach(v = files_to_load, .packages = c('Matrix', 'tidyr', 'dplyr', 'FactorHet')) %do% {
    all_posterior <- all_AME <- all_cjoint <- data.frame()
    
    bs_v <- readRDS(v)
    
    # True individual-level effects, keyed by a character unit id
    true_v_HTE <- bs_v$true_HTE %>%
      transmute(group = as.character(id), truth, factor, level)
    
    # Elements of bs_v that are not handled by the per-estimator loop below
    ignore_output <- c('seed', 'true_cluster', 'true_HTE', 
                       'cjbart', 'wrong_K', 'misspec', 'posterior', 'OOS_accuracy')
    # Nested misspecification fits are addressed as 'misspec@<outer>@<inner>'
    misspec_names <- paste0('misspec@',
                            unlist(mapply(names(bs_v$misspec), bs_v$misspec, SIMPLIFY=FALSE, 
                                          FUN=function(i,j){paste0(i,'@',names(j))}))
    )
    hte_names <- setdiff(names(bs_v), ignore_output)
    # Fits whose posterior is allowed to have fewer rows than true_cluster
    subsample_types <- c('half', 'ssplit',
                         misspec_names[grep(misspec_names, pattern='half|ssplit')])
    # True cluster label for unit i (units are numbered by position)
    truth_lookup <- data.frame(
      group = as.character(seq_along(bs_v$true_cluster)),
      truth = bs_v$true_cluster, stringsAsFactors = FALSE)
    
    for (type in c(hte_names, misspec_names)){
      
      if (grepl(type, pattern='misspec@')){
        split_type <- strsplit(type, split='@')[[1]]
        bs_v_object <- bs_v$misspec[[split_type[2]]][[split_type[3]]]
      }else{
        bs_v_object <- bs_v[[type]]
      }
      compare_posterior <- left_join(
        bs_v_object[['posterior']] %>% mutate(group = as.character(group)), 
        truth_lookup,
        by = 'group')
      
      if (!(type %in% subsample_types)){
        stopifnot(nrow(compare_posterior) == length(bs_v$true_cluster))
      }
      
      # Re-index rows by position and build the one-hot matrix of true
      # cluster membership (rows = posterior rows, columns = clusters)
      compare_posterior$group <- seq_len(nrow(compare_posterior))
      true_assignment <- as.matrix(Matrix::sparseMatrix(
        i = seq_len(nrow(bs_v_object$posterior)), 
        j = compare_posterior$truth, 
        x = 1, dims = c(nrow(bs_v_object$posterior), NCLUSTER)))
      truth_mat <- true_assignment[as.numeric(compare_posterior$group),]
      post_mat <- as.matrix(compare_posterior[, group_cols])
      
      # Relabelling of the estimated clusters chosen by FactorHet, plus the
      # mean absolute assignment error under every possible relabelling
      best_perm <- FactorHet:::internal_align(truth_mat, post_mat)
      perm_error <- vapply(all_perm, FUN=function(perm){
        mean(abs(truth_mat - post_mat[, perm]))
      }, FUN.VALUE = numeric(1))
      
      AME_v <- bs_v_object[['AME']]
      if (inherits(AME_v, 'FactorHet_vis')){
        AME_v <- AME_v$data
      }
      # Map estimated cluster labels onto the true labels before joining
      AME_v$group <- match(AME_v$group, best_perm)
      AME_v <- full_join(AME_v, true_AME, by = c('group', 'factor', 'level'))
      
      all_AME <- bind_rows(all_AME, 
                           AME_v %>% filter(!baseline) %>% 
                             mutate(sim = v, type = type, perm_error = min(perm_error))
      )
      
      cjoint_v <- bs_v_object[["cjoint"]]
      cjoint_v$group <- match(cjoint_v$group, best_perm)
      cjoint_v <- full_join(cjoint_v, true_flat_beta, by = c('group', 'factor', 'level'))
      all_cjoint <- bind_rows(all_cjoint, cjoint_v %>% mutate(sim = v, type = type))
      
      # Probability mass each unit's posterior puts on its true cluster
      posterior_in_truth <- rowSums(truth_mat * post_mat[, best_perm])
      
      # NOTE(review): this multiplies by true_assignment, which is built from
      # the rows of `posterior`; it assumes `posterior_predictive` has the
      # same units in the same order - confirm.
      postpred_in_truth <- left_join(
        bs_v_object[['posterior_predictive']] %>% mutate(group = as.character(group)), 
        truth_lookup,
        by = 'group')
      postpred_in_truth <- rowSums(
        true_assignment * 
          as.matrix(postpred_in_truth[, group_cols])[,best_perm]
      )
      
      all_posterior <- bind_rows(all_posterior,
                                 data.frame(
                                   posterior = mean(posterior_in_truth),
                                   postpred = mean(postpred_in_truth),
                                   median_posterior = median(posterior_in_truth),
                                   median_postpred = median(postpred_in_truth),
                                   cor_truth = cor(posterior_in_truth, postpred_in_truth)
                                 ) %>% mutate(sim = v, type = type, perm_error = min(perm_error))
      )
      
      # Sanity checks on the join with the truth: baseline levels must have
      # a true effect of zero and every row must have a truth value
      stopifnot(all(AME_v %>% filter(baseline) %>% pull(truth) == 0))
      stopifnot(!any(is.na(AME_v$truth)))
    }
    rownames(all_posterior) <- NULL
    
    # Collect the HTE objects of every method: main fits, misspecified fits
    # (except 'no_mod'), cjbart, and the fits with a wrong number of clusters
    all_HTE <- lapply(bs_v[hte_names], `[[`, "HTE")
    all_HTE_misspec <- lapply(bs_v$misspec[names(bs_v$misspec) != 'no_mod'], FUN=function(i){lapply(i, `[[`, 'HTE')})
    all_HTE <- c(all_HTE, unlist(all_HTE_misspec, recursive = FALSE))
    all_HTE <- c(all_HTE, list(cjbart = list(individual = bs_v$cjbart$HTE, population = NA)))
    # NOTE(review): `i$full` / `i$ssplit` here versus `i$full_HTE` /
    # `i$ssplit_HTE` further down - if the elements are named with the _HTE
    # suffix, these two lookups rely on `$` partial matching; confirm.
    all_HTE <- c(all_HTE, lapply(bs_v$wrong_K, FUN=function(i){
      if (is.null(i)){return(NULL)}
      full_HTE <- i$full
      split_HTE <- i$ssplit
      out <- list(full_HTE, split_HTE)
      names(out) <- paste0('wrong_', i$K, c('_full', '_split'))
      return(out)
    }) %>% unlist(., recursive = FALSE)
    )
    
    # Score each method's individual-level estimates against the truth
    all_HTE <- lapply(all_HTE, FUN=function(i){
      if (!('ll' %in% names(i$individual))){
        # No interval supplied: build a normal 95% interval from the
        # variance, truncating negative variances at zero
        i$individual$var[i$individual$var < 0] <- 0
        i$individual$ll <- i$individual$est - 1.96 * sqrt(i$individual$var)
        i$individual$ul <- i$individual$est + 1.96 * sqrt(i$individual$var)
      }
      pre_N <- nrow(i$individual)
      i <- left_join(i$individual %>% mutate(group = as.character(group)), 
                     true_v_HTE, 
                     by = c('group', 'factor', 'level'))
      # The join must not duplicate rows
      stopifnot(nrow(i) == pre_N)
      out_avg <- i %>% ungroup %>% summarize(
        mae = mean(abs(est - truth)),
        rmse = sqrt(mean( (est - truth)^2 )),
        bias = mean(est - truth),
        var_diff = var(est - truth),
        coverage = mean( (truth > ll) & (truth < ul)))
      # Mean estimate within bins of the true effect
      out_binscatter <- i %>% 
        mutate(group = cut(truth, grid_hte)) %>% 
        group_by(group) %>% 
        summarize(est = mean(est), n = n())
      out_avg_by_AME <- i %>% ungroup %>% 
        group_by(factor, level) %>%  
        summarize(
          avg_est = mean(est),
          var_est = var(est),
          avg_truth = mean(truth),
          var_truth = var(truth),
          var_diff = var(est - truth),
          .groups = 'drop'
        )
      
      return(list(out = out_avg, 
                  out_by_AME = out_avg_by_AME,
                  binscatter = out_binscatter))
    })
    
    
    summarize_HTE <- list(
      out = bind_rows(lapply(all_HTE, `[[`, "out"), .id = 'method') %>%
        mutate(sim = v),
      binscatter = bind_rows(lapply(all_HTE, `[[`, "binscatter"), .id = 'method') %>%
        mutate(sim = v),
      out_by_AME = bind_rows(lapply(all_HTE, `[[`, "out_by_AME"), .id = 'method') %>%
        mutate(sim = v)
    )
    
    # Sum-of-squares decomposition of the TRUE effects under each fit's
    # posterior-predictive cluster weights (elements 2 and 4 of wrong_K)
    sos_wrong_K <- lapply(bs_v$wrong_K[c(2, 4)], FUN=function(i){
      sos_full <- heterogeneity_HTE(
        true_v_HTE %>% rename(est = truth),
        i$full_HTE$posterior_predictive
      ) %>% mutate(type = paste0('wrong_', i$K, '_full'))
      sos_ssplit <- heterogeneity_HTE(
        true_v_HTE %>% rename(est = truth), 
        i$ssplit_HTE$posterior_predictive
      ) %>% mutate(type = paste0('wrong_', i$K, '_split'))
      return(rbind(sos_full, sos_ssplit))
    }) %>% bind_rows()
    
    sos_main <- lapply(hte_names, FUN=function(i){
      heterogeneity_HTE(
        true_v_HTE %>% rename(est = truth),
        bs_v[[i]]$HTE$posterior_predictive
      )
    }) 
    names(sos_main) <- hte_names
    sos_main <- sos_main %>% bind_rows(.id = 'type')
    
    sos_out <- rbind(sos_wrong_K, sos_main) %>% mutate(sim = v)
    
    all_OOS <- bs_v$OOS_accuracy %>% mutate(sim = v)
    
    return(list(all_AME = all_AME, all_sos = sos_out,
                all_posterior = all_posterior,
                all_OOS = all_OOS,
                all_cjoint = all_cjoint, all_HTE = summarize_HTE))
  }
  
  # Stack the per-file results of this archive
  all_AME <- lapply(loop_parse, `[[`, "all_AME") %>% bind_rows()
  all_cjoint <- lapply(loop_parse, `[[`, "all_cjoint") %>% bind_rows()
  all_posterior <- lapply(loop_parse, `[[`, "all_posterior") %>% bind_rows()
  all_sos <- lapply(loop_parse, `[[`, "all_sos") %>% bind_rows()
  all_HTE <- lapply(loop_parse, FUN=function(i){
    i$all_HTE$out
  }) %>% bind_rows()
  all_HTE_binscatter <- lapply(loop_parse, FUN=function(i){
    i$all_HTE$binscatter
  }) %>% bind_rows()
  all_HTE_by_AME <- lapply(loop_parse, FUN=function(i){
    i$all_HTE$out_by_AME
  }) %>% bind_rows()
  all_OOS <- lapply(loop_parse, `[[`, "all_OOS") %>% bind_rows()
  
  
  # Label every table with the sample-size regime. The label is parsed from
  # the file name only, so a temp-directory path that happens to contain
  # 'small' or 'large' cannot be picked up instead.
  add_sample_size <- function(dta){
    dta %>% mutate(sample_size = str_extract(basename(sim), pattern='large|small|very_large'))
  }
  all_AME <- add_sample_size(all_AME)
  all_cjoint <- add_sample_size(all_cjoint)
  all_posterior <- add_sample_size(all_posterior)
  all_HTE <- add_sample_size(all_HTE)
  all_sos <- add_sample_size(all_sos)
  all_HTE_binscatter <- add_sample_size(all_HTE_binscatter)
  all_HTE_by_AME <- add_sample_size(all_HTE_by_AME)
  all_OOS <- add_sample_size(all_OOS)
  
  # Standard errors (negative variances set to zero), a flag for estimates
  # that are numerically non-zero, and 95% interval coverage of the truth
  all_cjoint <- all_cjoint %>% 
    mutate(neg_variance = var < 0) %>%
    mutate(se = ifelse(var < 0, 0, sqrt(var)))  %>%
    mutate(selected = abs(coef) > sqrt(.Machine$double.eps)) %>%
    mutate(coverage = (coef - 1.96 * se <= truth) & ( coef + 1.96 * se >= truth))
  all_AME <- all_AME %>%  
    mutate(neg_variance = var < 0) %>%
    mutate(se = ifelse(var < 0, 0, sqrt(var)))  %>%
    mutate(selected = abs(marginal_effect) > sqrt(.Machine$double.eps)) %>%
    mutate(coverage = (marginal_effect - 1.96 * se <= truth) & ( marginal_effect + 1.96 * se >= truth))
  
  print(path)
  
  out <- list(
    AME = all_AME, cjoint = all_cjoint, 
    posterior = all_posterior, 
    sos = all_sos,
    HTE = all_HTE, 
    OOS = all_OOS,
    HTE_binscatter = all_HTE_binscatter, 
    HTE_by_AME = all_HTE_by_AME)
  # Remove the unpacked archive
  unlink(scratch_dir, recursive = TRUE)
  
  return(out)
}

# Shut the workers down
stopCluster(cl); rm(cl)
stopImplicitCluster()

# Stack each kind of table across archives. `file` records which element of
# all_out a row came from.
all_AME <- lapply(all_out, function(res) res[['AME']]) %>%
  bind_rows(.id = 'file')
all_cjoint <- lapply(all_out, function(res) res[['cjoint']]) %>%
  bind_rows(.id = 'file')
all_posterior <- lapply(all_out, function(res) res[['posterior']]) %>%
  bind_rows(.id = 'file')
all_sos <- lapply(all_out, function(res) res[['sos']]) %>%
  bind_rows(.id = 'file')
all_HTE <- lapply(all_out, function(res) res[['HTE']]) %>%
  bind_rows(.id = 'file')
all_HTE_binscatter <- lapply(all_out, function(res) res[['HTE_binscatter']]) %>%
  bind_rows(.id = 'file')
all_HTE_by_AME <- lapply(all_out, function(res) res[['HTE_by_AME']]) %>%
  bind_rows(.id = 'file')
all_OOS <- lapply(all_out, function(res) res[['OOS']]) %>%
  bind_rows(.id = 'file')

# Write one RDS per table into final_output/
saveRDS(object = all_AME, file = 'final_output/final_AME.RDS')
saveRDS(object = all_cjoint, file = 'final_output/final_cjoint.RDS')
saveRDS(object = all_posterior, file = 'final_output/final_posterior.RDS')
saveRDS(object = all_sos, file = 'final_output/final_sos.RDS')
saveRDS(object = all_HTE, file = 'final_output/final_HTE.RDS')
saveRDS(object = all_HTE_binscatter, file = 'final_output/final_HTE_binscatter.RDS')
saveRDS(object = all_HTE_by_AME, file = 'final_output/final_HTE_by_AME.RDS')
saveRDS(object = all_OOS, file = 'final_output/final_OOS.RDS')
